import ast
from collections import Counter, defaultdict
import functools
import random
import signal
from typing import Any, Dict, List, Literal, Optional, Tuple, Callable, TypeVar, TypeAlias

import matplotlib.pyplot as plt
import networkx as nx

from src.visualization import arc_cmap, arc_norm

# Layout names accepted by `visualize_dag`; each one selects the networkx layout
# function of the same name (e.g. "spring" -> `nx.spring_layout`).
GraphMode = Literal["spring", "planar", "circular", "kamada_kawai", "random", "shell", "spectral", "spiral"]


def ensure_bounds(*x: int, low: int = 1, high: int = 30) -> int:
    """Ensure that all the inputs are within bounds (default to between 1 and 30)."""
    if len(x) == 1:
        return max(low, min(x[0], high))
    else:
        return tuple(max(low, min(i, high)) for i in x)


def int_in_bounds(i: int) -> bool:
    """Report whether ``i`` lies strictly between -30 and 900 (both ends exclusive)."""
    if i <= -30:
        return False
    return i < 900


def plot_task(task: list[dict] | dict, title: Optional[str] = None, figsize_factor: float = 3) -> None:
    """Show each example's input grid (top row) above its output grid (bottom row).

    Every subplot is captioned with the grid size as ``"<dim0>x<dim1>"``, and axes are hidden.

    Args:
        task: One example dict, or a list of them. Each example needs ``"input"`` and
            ``"output"`` entries that ``imshow`` accepts and that expose ``.shape``
            (presumably 2-D NumPy arrays -- confirm with callers).
        title: Optional figure-level title, drawn with font size 20.
        figsize_factor: Figure inches per subplot along each axis.
    """
    if isinstance(task, dict):
        task = [task]
    rows = ("input", "output")
    height, width = len(rows), len(task)
    figure_size = (width * figsize_factor, height * figsize_factor)
    # squeeze=False keeps `axes` 2-D even for a single example, so axes[row, column] always works.
    figure, axes = plt.subplots(height, width, figsize=figure_size, squeeze=False)
    for column, example in enumerate(task):
        for row, key in enumerate(rows):
            ax = axes[row, column]
            grid = example[key]
            ax.imshow(grid, cmap=arc_cmap, norm=arc_norm, origin="lower")
            # Size caption centred just below the axes (axes-fraction coordinates).
            ax.text(
                0.5,
                -0.02,
                "{}x{}".format(*grid.shape),
                transform=ax.transAxes,
                ha="center",
                va="top",
            )
            ax.axis("off")

    if title is not None:
        figure.suptitle(title, fontsize=20)
    plt.tight_layout()
    plt.show()


def get_node_predecessors(G: nx.MultiDiGraph, node_id: int) -> List[int]:
    """Return the predecessors of ``node_id``, repeating each once per parallel edge.

    ``G.pred[node_id]`` maps every predecessor to a sized collection of the edges it
    sends into ``node_id`` (one entry per parallel edge in a multigraph). A predecessor
    with ``k`` such edges is therefore listed ``k`` times, in consecutive positions,
    following ``G.pred`` iteration order.
    """
    return [
        source
        for source, parallel_edges in G.pred[node_id].items()
        for _ in range(len(parallel_edges))
    ]


def visualize_dag(
    G: nx.Graph,
    mode: GraphMode = "spring",
    title: Optional[str] = None,
    font_size: int = 8,
) -> None:
    """Draw the computational Directed Acyclic Graph and show the figure.

    Nodes are positioned with the networkx layout named by ``mode`` and labelled with
    their ``"primitive"`` node attribute (a node lacking it raises ``KeyError``).

    Args:
        G: Graph to draw.
        mode: Layout name; one of the ``GraphMode`` literals.
        title: Optional plot title.
        font_size: Font size of the node labels.

    Raises:
        ValueError: If ``mode`` is not a supported layout name.
    """
    layouts = {
        "spring": nx.spring_layout,
        "planar": nx.planar_layout,
        "circular": nx.circular_layout,
        "kamada_kawai": nx.kamada_kawai_layout,
        "random": nx.random_layout,
        "shell": nx.shell_layout,
        "spectral": nx.spectral_layout,
        "spiral": nx.spiral_layout,
    }
    if mode not in layouts:
        raise ValueError(f"Invalid mode {mode!r}. Choose from {tuple(layouts)}.")
    # Every layout, including kamada_kawai, must be *called* to obtain a {node: position} dict.
    pos = layouts[mode](G)
    nx.draw(G, pos, with_labels=False, node_color="lightblue", edge_color="gray")
    labels = {node_id: node["primitive"] for node_id, node in G.nodes(data=True)}
    nx.draw_networkx_labels(G, pos, labels, font_size=font_size)
    if title is not None:
        plt.title(title)
    plt.show()


def is_grid(grid: Any) -> bool:
    """Check if the input is a valid grid in terms of structure and values.

    A valid grid is a tuple of 1-30 rows. Every row is a tuple with the same length
    as the first row (1-30). Every pixel is an ``int`` in ``0..9``. ``bool`` values
    are rejected even though ``bool`` subclasses ``int``.
    """
    # Cheap structural checks first, so malformed inputs never reach the pixel scan.
    if not isinstance(grid, tuple) or not 1 <= len(grid) <= 30:
        return False
    first_row = grid[0]
    if not isinstance(first_row, tuple) or not 1 <= len(first_row) <= 30:
        return False
    width = len(first_row)
    return all(
        isinstance(row, tuple)
        and len(row) == width
        # One pass per row; True/False are not pixel values despite being ints.
        and all(isinstance(pixel, int) and not isinstance(pixel, bool) and 0 <= pixel <= 9 for pixel in row)
        for row in grid
    )


class PrimitiveVisitor(ast.NodeVisitor):
    """Count primitive calls and constants appearing in a Python AST.

    Counts accumulate in ``primitive_count``:

    * each call ``f(...)`` whose callee is a bare name increments ``f``. When
      ``primitives`` is non-empty, only callees in that set are counted;
    * each bare-name positional argument that is in ``primitives`` (a primitive
      passed as a value) increments ``const_<name>``;
    * each name in ``const_mapping`` increments its mapped ``const_*`` label;
    * each integer literal in ``0..10`` increments ``const_<n>``, and ``True``/``False``
      literals increment ``const_true``/``const_false``.
    """

    def __init__(self, primitives: Optional[set[str]] = None):
        self.primitive_count = Counter()
        # Named DSL constants -> label under which they are counted.
        self.const_mapping = {
            "ZERO": "const_0",
            "ONE": "const_1",
            "TWO": "const_2",
            "THREE": "const_3",
            "FOUR": "const_4",
            "FIVE": "const_5",
            "SIX": "const_6",
            "SEVEN": "const_7",
            "EIGHT": "const_8",
            "NINE": "const_9",
            "TEN": "const_10",
            "T": "const_true",
            "F": "const_false",
            "NEG_ONE": "const_neg1",
            "NEG_TWO": "const_neg2",
            "ORIGIN": "const_origin",
            "UNITY": "const_unity",
            "DOWN": "const_down",
            "RIGHT": "const_right",
            "UP": "const_up",
            "LEFT": "const_left",
            "NEG_UNITY": "const_neg_unity",
            "UP_RIGHT": "const_up_right",
            "DOWN_LEFT": "const_down_left",
            "ZERO_BY_TWO": "const_zero_by_two",
            "TWO_BY_ZERO": "const_two_by_zero",
            "TWO_BY_TWO": "const_two_by_two",
            "THREE_BY_THREE": "const_three_by_three",
        }
        # An empty set means "count every bare-name call".
        self.primitives = primitives or set()

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name) and (not self.primitives or node.func.id in self.primitives):
            self.primitive_count[node.func.id] += 1

        # Primitives passed by name as arguments (higher-order use) are counted as constants.
        for arg in node.args:
            if isinstance(arg, ast.Name) and arg.id in self.primitives:
                self.primitive_count[f"const_{arg.id}"] += 1

        self.generic_visit(node)

    def visit_Name(self, node):
        if node.id in self.const_mapping:
            self.primitive_count[self.const_mapping[node.id]] += 1
        self.generic_visit(node)

    def visit_Constant(self, node):
        # `.value` is the supported attribute; `.n` is a deprecated alias.
        value = node.value
        if isinstance(value, bool):
            # bool subclasses int: label True/False like the named constants T/F
            # instead of producing "const_True"/"const_False".
            self.primitive_count["const_true" if value else "const_false"] += 1
        elif isinstance(value, int) and 0 <= value <= 10:
            self.primitive_count[f"const_{value}"] += 1


def count_primitives_in_module(module_code: str, primitives: Optional[set[str]] = None) -> Dict[str, int]:
    """Count primitive calls and constants across the top-level functions of ``module_code``.

    Only top-level ``def`` statements are visited; other module-level code, classes and
    ``async def`` functions are ignored. Each function is scanned with ``PrimitiveVisitor``.
    When two top-level functions share a name, only the later one is counted.

    Args:
        module_code: Python source to parse.
        primitives: Optional set restricting which call names are counted.

    Returns:
        Total counts keyed by primitive/constant label, ordered by count (descending).
        Ties keep first-seen order.
    """
    tree = ast.parse(module_code)
    function_primitives: Dict[str, Dict[str, int]] = {}

    for node in tree.body:
        if not isinstance(node, ast.FunctionDef):
            continue
        visitor = PrimitiveVisitor(primitives)
        visitor.visit(node)
        primitive_counts = dict(visitor.primitive_count)
        # Fold any raw constant name that slipped through as a key (e.g. a call `ZERO()`)
        # into its `const_*` label, adding to rather than overwriting an existing count.
        for re_arc_const, constant_primitive in visitor.const_mapping.items():
            if re_arc_const in primitive_counts:
                primitive_counts[constant_primitive] = primitive_counts.get(constant_primitive, 0) + primitive_counts.pop(
                    re_arc_const
                )
        function_primitives[node.name] = primitive_counts

    total_count: Counter = Counter()
    for func_counts in function_primitives.values():
        total_count.update(func_counts)
    # most_common() sorts by count, descending, and is stable for ties.
    return dict(total_count.most_common())


class PrimitiveInputVisitor(ast.NodeVisitor):
    """Record, for every primitive call, what feeds each of its positional arguments.

    ``var_mapping`` maps a variable to the name of the function whose call result was
    assigned to it; ``I`` is pre-seeded as ``"toinput"``. For each call ``f(...)`` with a
    bare-name callee, one tuple is appended to ``primitive_inputs[f]``. When ``primitives``
    is non-empty, only callees in that set are recorded. The tuple has one entry per
    Name/Constant positional argument:

    * a mapped variable resolves (transitively) to its producing function's name;
    * a name in ``const_mapping`` becomes its ``const_*`` label;
    * any other name becomes ``const_<name>``;
    * an int literal in ``0..10`` becomes ``const_<n>``, and ``True``/``False`` become
      ``const_true``/``const_false``.

    Other argument kinds (nested calls, attributes, ...) are skipped. ``return x`` is
    recorded as a one-argument call to ``"tooutput"``. The visited AST is not modified.
    """

    def __init__(self, primitives: Optional[set[str]] = None):
        self.primitive_inputs = defaultdict(list)
        # Named DSL constants -> label recorded as an input.
        self.const_mapping = {
            "ZERO": "const_0",
            "ONE": "const_1",
            "TWO": "const_2",
            "THREE": "const_3",
            "FOUR": "const_4",
            "FIVE": "const_5",
            "SIX": "const_6",
            "SEVEN": "const_7",
            "EIGHT": "const_8",
            "NINE": "const_9",
            "TEN": "const_10",
            "T": "const_true",
            "F": "const_false",
            "NEG_ONE": "const_neg1",
            "NEG_TWO": "const_neg2",
            "ORIGIN": "const_origin",
            "UNITY": "const_unity",
            "DOWN": "const_down",
            "RIGHT": "const_right",
            "UP": "const_up",
            "LEFT": "const_left",
            "NEG_UNITY": "const_neg_unity",
            "UP_RIGHT": "const_up_right",
            "DOWN_LEFT": "const_down_left",
            "ZERO_BY_TWO": "const_zero_by_two",
            "TWO_BY_ZERO": "const_two_by_zero",
            "TWO_BY_TWO": "const_two_by_two",
            "THREE_BY_THREE": "const_three_by_three",
        }
        # An empty set means "record every bare-name call".
        self.primitives = primitives or set()
        self.var_mapping = {"I": "toinput"}

    def _resolve(self, name: str) -> str:
        """Follow ``var_mapping`` from ``name`` until reaching a name it does not map.

        Stops at the first repeated name, so a cyclic mapping cannot loop forever.
        """
        seen = {name}
        while name in self.var_mapping:
            target = self.var_mapping[name]
            if target in seen:
                break
            seen.add(target)
            name = target
        return name

    def visit_Assign(self, node):
        assert len(node.targets) == 1
        # Visit the right-hand side before binding the target, so in `x = g(x)` the
        # argument `x` resolves to its previous producer rather than to `g` itself.
        self.generic_visit(node)
        target, value = node.targets[0], node.value
        # Only plain-name calls can be mapped; `x = obj.method()` has no simple callee name.
        if isinstance(target, ast.Name) and isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
            self.var_mapping[target.id] = value.func.id

    def visit_Return(self, node):
        # Treat return statement as if it were a call to "tooutput"
        assert isinstance(node.value, ast.Name)
        assert node.value.id in self.var_mapping
        self.primitive_inputs["tooutput"].append((self._resolve(node.value.id),))
        self.generic_visit(node)

    def visit_Call(self, node):
        func = node.func
        if isinstance(func, ast.Name) and (not self.primitives or func.id in self.primitives):
            inputs = []
            for arg in node.args:
                if isinstance(arg, ast.Name):
                    name = arg.id
                    if name in self.var_mapping:
                        # Resolve into a local so the visited AST is left unmodified.
                        inputs.append(self._resolve(name))
                    elif name in self.const_mapping:
                        inputs.append(self.const_mapping[name])
                    else:
                        inputs.append(f"const_{name}")
                elif isinstance(arg, ast.Constant):
                    value = arg.value
                    if isinstance(value, bool):
                        # bool subclasses int; label it like the named constants T/F.
                        inputs.append("const_true" if value else "const_false")
                    else:
                        assert isinstance(value, int) and 0 <= value <= 10
                        inputs.append(f"const_{value}")
            self.primitive_inputs[func.id].append(tuple(inputs))
        self.generic_visit(node)


def count_primitive_inputs_in_module(
    module_code: str, primitives: Optional[set[str]] = None
) -> Dict[str, Tuple[Dict[str, int], ...]]:
    """Tally, per primitive and per argument position, which inputs feed each call.

    ``module_code`` is scanned with ``PrimitiveInputVisitor``. Every recorded call of a
    given primitive must have the same number of inputs (asserted).

    Args:
        module_code: Python source to parse.
        primitives: Optional set restricting which call names are recorded.

    Returns:
        ``{primitive: (counts_for_arg0, counts_for_arg1, ...)}``. Each ``counts_for_argN``
        maps an input label (producer function or ``const_*``) to its number of occurrences.
    """
    tree = ast.parse(module_code)
    visitor = PrimitiveInputVisitor(primitives)
    visitor.visit(tree)

    result: Dict[str, Tuple[Dict[str, int], ...]] = {}
    for primitive, input_lists in visitor.primitive_inputs.items():
        # input_lists is never empty: an entry exists only after at least one append.
        num_inputs = len(input_lists[0])
        per_position = tuple(Counter() for _ in range(num_inputs))
        for inputs in input_lists:
            assert len(inputs) == num_inputs, (
                f"{primitive!r} recorded with {len(inputs)} inputs, expected {num_inputs}"
            )
            for position, input_name in enumerate(inputs):
                per_position[position][input_name] += 1
        result[primitive] = tuple(dict(counts) for counts in per_position)
    return result


class EMA:
    """Bias-corrected exponential moving average of successive differences.

    Each call feeds an observation ``x``. The smoothed quantity is the change
    ``x - previous`` (the first change is measured against ``start``). It is divided by
    ``1 - (1 - alpha) ** calls`` to correct for starting from zero. With
    ``return_inverse=True`` the reciprocal form ``correction / (diff + eps)`` is returned.
    """

    def __init__(self, start: float, smoothing: float = 0.3, return_inverse: bool = False, eps: float = 1e-8):
        # Clamp the smoothing factor into [eps, 1 - eps].
        upper_clamped = min(smoothing, 1 - eps)
        self.alpha = max(eps, upper_clamped)
        self.eps = eps
        self.return_inverse = return_inverse
        # Running state: previous observation, smoothed difference, number of updates.
        self.last_value = start
        self.diff = 0
        self.calls = 0

    def __call__(self, x: float) -> float:
        decay = 1 - self.alpha
        delta = x - self.last_value
        self.diff = self.alpha * delta + decay * self.diff
        self.last_value = x
        self.calls += 1
        correction = 1 - decay**self.calls
        if not self.return_inverse:
            return self.diff / correction
        return correction / (self.diff + self.eps)


# Result type of the zero-argument callable wrapped by `run_with_timeout`.
T = TypeVar("T")
# A `random.getstate()` snapshot, or None when no RNG state is threaded through.
OptionalRandomState: TypeAlias = Optional[Tuple]


def run_with_timeout(
    func: Callable[[], T], timeout: float
) -> Callable[[OptionalRandomState], Tuple[Optional[T], OptionalRandomState, Optional[Exception]]]:
    """Wrap ``func`` so each call runs under a wall-clock limit, optionally with a given RNG state.

    The returned ``wrapper(random_state=None)`` installs ``random_state`` (a
    ``random.getstate()`` snapshot) before calling ``func`` when one is given. It returns
    ``(result, random_state, exception)``:

    * success: ``(func(), state after func or None, None)``;
    * timeout: ``(None, new state or None, TimeoutError)``. The new state is derived by
      restoring the input state and reseeding from it, so it is deterministic;
    * any other ``Exception``: ``(None, state at the failure or None, exc)``.

    ``KeyboardInterrupt`` is not caught. The limit is enforced with ``SIGALRM`` through
    ``signal.setitimer``, which is Unix only; ``signal.signal`` may only be called from the
    main thread. The previous SIGINT and SIGALRM handlers are restored after every call.

    Args:
        func: Zero-argument callable to run.
        timeout: Limit in seconds. Fractional values are allowed; 0 disables the limit.
    """

    def timeout_handler(signum, frame):
        raise TimeoutError(f"Function execution timed out after {timeout} seconds")

    @functools.wraps(func)
    def wrapper(random_state: OptionalRandomState = None):
        result = None
        exception = None

        # Store the original SIGINT and SIGALRM handlers
        original_sigint_handler = signal.getsignal(signal.SIGINT)
        original_sigalrm_handler = signal.getsignal(signal.SIGALRM)

        def handle_interrupt(signum, frame):
            # Restore the original handlers before propagating the interrupt
            signal.signal(signal.SIGINT, original_sigint_handler)
            signal.signal(signal.SIGALRM, original_sigalrm_handler)
            raise KeyboardInterrupt()

        try:
            signal.signal(signal.SIGALRM, timeout_handler)
            signal.signal(signal.SIGINT, handle_interrupt)
            # setitimer (unlike signal.alarm) accepts fractional seconds.
            signal.setitimer(signal.ITIMER_REAL, timeout)
            try:
                if random_state is not None:
                    random.setstate(random_state)
                result = func()
            finally:
                # Disarm as soon as func finishes, so the alarm cannot fire during the
                # bookkeeping or inside an except handler below and escape the wrapper.
                signal.setitimer(signal.ITIMER_REAL, 0)
            random_state = random.getstate() if random_state is not None else None
        except TimeoutError as e:
            # Timeout occurred: move to a fresh but reproducible RNG state.
            if random_state is not None:
                random.setstate(random_state)
                random.seed(random.randint(0, 2**32 - 1))
                random_state = random.getstate()
            exception = e
        except Exception as e:
            # Some other exception occurred
            random_state = random.getstate() if random_state is not None else None
            exception = e
        finally:
            # Make sure the timer is off and restore the original SIGINT and SIGALRM handlers
            signal.setitimer(signal.ITIMER_REAL, 0)
            signal.signal(signal.SIGINT, original_sigint_handler)
            signal.signal(signal.SIGALRM, original_sigalrm_handler)
        return result, random_state, exception

    return wrapper
